{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create a Distillation Pipeline to Distill DeepSeek-R1 into Qwen model with NeNo 2.0 Framework\n",
    "\n",
    "In the field of LLMs, reasoning models leverage deep thinking capabilities to significantly enhance model performance across complex scenarios. According to the [DeepSeek-R1](https://arxiv.org/abs/2501.12948) paper, the reasoning pattern of larger models can be distilled into smaller models. Specifically, we can distill long-chain-of-thought (long-CoT) data that includes reasoning processes from DeepSeek-R1 and directly fine-tune open-source models like Qwen and Llama. This straightforward distillation method significantly enhances the reasoning abilities of smaller models.\n",
    "\n",
    "To demonstrate the complete distillation process, we have prepared two notebooks that cover how to distill reasoning data from DeepSeek-R1 using the NIM API, and how to train models using the distilled data.\n",
    "\n",
    "\n",
    "- [generate_reasoning_data.ipynb](./generate_reasoning_data.ipynb) demonstrates how to distill reasoning data from DeepSeek-R1 using the NIM API. \n",
    "- [qwen2_distill_nemo.ipynb](./qwen2_distill_nemo.ipynb) (⭐) shows how to train open-source models using the distilled data.\n",
    "\n",
    "\n",
    "\n",
    "This notebook is part 2 of the series, and it demonstrates how to distill the reasoning ability of DeepSeek-R1 into the Qwen model using the NeMo Framework. The training is based on a custom dataset [Bespoke-Stratos-17k](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k) distilled from deepseek-R1, which includes 17k reasoning chains in mathematics, code, Olympiads, science, and puzzles. You can also use the [generate_reasoning_data.ipynb](./generate_reasoning_data.ipynb) to generate your own reasoning data.\n",
    "\n",
    "\n",
    "## Preparation\n",
    "We use NeMo 2.0 and NeMo-Run to fine-tune (SFT) the Qwen model. You can start and enter the container `nvcr.io/nvidia/nemo:25.02`.\n",
    "\n",
    "\n",
    "## Step-By-Step Instructions\n",
    "\n",
    "This notebook contains five steps:\n",
    "\n",
    "1. Download the Qwen model and convert it to NeMo 2.0 format.\n",
    "2. Prepare the fine-tuning data.\n",
    "3. Fine-tune the Qwen model with NeMo 2.0 and NeMo-Run.\n",
    "4. Evaluate the model.\n",
    "5. Convert the output model from NeMo 2.0 to HF format."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Download Qwen and Convert to NeMo 2.0 Format\n",
    "\n",
    "First, download the Qwen-2.5-7B-Instruct model from the Hugging Face model hub and convert it to NeMo format.\n",
    "\n",
    "We use the `llm.import_ckpt` API to download the model using `hf://<huggingface_model_id>` and convert it to NeMo 2.0 format.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import nemo_run as run\n",
    "import lightning.pytorch as pl\n",
    "from nemo.collections import llm\n",
    "\n",
    "# llm.import_ckpt is the nemo2 API for converting Hugging Face checkpoint to NeMo format\n",
    "# example python usage:\n",
    "# llm.import_ckpt(model=llm.llama3_8b.model(), source=\"hf://meta-llama/Meta-Llama-3-8B\")\n",
    "#\n",
    "# We use run.Partial to configure this function\n",
    "def configure_checkpoint_conversion():\n",
    "    return run.Partial(\n",
    "        llm.import_ckpt,\n",
    "        model=llm.qwen2_7b.model(),\n",
    "        source=\"hf://Qwen/Qwen2.5-7B-Instruct\",\n",
    "        overwrite=True,\n",
    "    )\n",
    "\n",
    "# configure your function\n",
    "import_ckpt = configure_checkpoint_conversion()\n",
    "# define your executor\n",
    "local_executor = run.LocalExecutor()\n",
    "\n",
    "# run your experiment\n",
    "run.run(import_ckpt, executor=local_executor)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Prepare Fine-tuning Data\n",
    "\n",
    "[Bespoke-Stratos-17k](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k) is an open-souced dataset created by Bespoke Labs. It contains 17k reasoning questions from math, code, Olympiads, science and puzzle areas, which is distilled from DeepSeek-R1.\n",
    "\n",
    "Each case includes a question, a reasoning chain, and an answer. There is an example:\n",
    "\n",
    "```\n",
    "{\n",
    "    \"system\": \"Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\\\\n\\\\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:\",\n",
    "    \"conversations\": [\n",
    "      {\n",
    "        \"from\": \"user\",\n",
    "        \"value\": \"Return your final response within \\\\boxed{}. Two counterfeit coins of equal weight are mixed with $8$ identical genuine coins. The weight of each of the counterfeit coins is different from the weight of each of the genuine coins. A pair of coins is selected at random without replacement from the $10$ coins. A second pair is selected at random without replacement from the remaining $8$ coins. The combined weight of the first pair is equal to the combined weight of the second pair. What is the probability that all $4$ selected coins are genuine?\\n$\\\\textbf{(A)}\\\\ \\\\frac{7}{11}\\\\qquad\\\\textbf{(B)}\\\\ \\\\frac{9}{13}\\\\qquad\\\\textbf{(C)}\\\\ \\\\frac{11}{15}\\\\qquad\\\\textbf{(D)}\\\\ \\\\frac{15}{19}\\\\qquad\\\\textbf{(E)}\\\\ \\\\frac{15}{16}$\"\n",
    "      },\n",
    "      {\n",
    "        \"from\": \"assistant\",\n",
    "        \"value\": \"<|begin_of_thought|>\\n\\nOkay, so I need to solve this probability problem. Let me read it again and make sure I understand what's being asked. \\n\\nThere are two counterfeit coins that have the same weight, and they're mixed with 8 genuine coins. ...\n",
    "        ...the probability that all four selected coins are genuine is \\\\(\\\\boxed{D}\\\\).\\n\\n<|end_of_solution|>\"\n",
    "      }\n",
    "    ],\n",
    "  },\n",
    "  ```\n",
    "\n",
    "  To use this dataset, we need to write a custom dataloader \"BespokeDataModule\" to load the data from the dataset and support chat format."
   ]
  },
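  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before the full data module, here is a minimal standalone sketch of the per-record conversion it applies when writing the JSONL files: role names are capitalized and the `mask` and `type` fields are added. The names `ROLE_MAP` and `to_chat_record` and the toy record are illustrative assumptions, not part of NeMo or the dataset.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Raw dataset role names mapped to the names written to the JSONL files.\n",
    "ROLE_MAP = {\"user\": \"User\", \"assistant\": \"Assistant\"}\n",
    "\n",
    "\n",
    "def to_chat_record(example: dict) -> dict:\n",
    "    \"\"\"Return a copy of `example` with renamed roles plus the `mask` and `type` fields.\"\"\"\n",
    "    turns = []\n",
    "    for turn in example[\"conversations\"]:\n",
    "        if turn[\"from\"] not in ROLE_MAP:\n",
    "            raise ValueError(f\"Unknown role: {turn['from']}\")\n",
    "        turns.append({\"from\": ROLE_MAP[turn[\"from\"]], \"value\": turn[\"value\"]})\n",
    "    return {**example, \"conversations\": turns, \"mask\": \"User\", \"type\": \"VALUE_TO_TEXT\"}\n",
    "\n",
    "\n",
    "# Hypothetical toy record, not taken from the dataset.\n",
    "sample = {\n",
    "    \"system\": \"You are a helpful assistant.\",\n",
    "    \"conversations\": [\n",
    "        {\"from\": \"user\", \"value\": \"What is 1 + 1?\"},\n",
    "        {\"from\": \"assistant\", \"value\": \"2\"},\n",
    "    ],\n",
    "}\n",
    "record = to_chat_record(sample)\n",
    "print(json.dumps(record))\n",
    "```\n",
    "\n",
    "Unlike the in-place loop in the data module, this sketch returns a new record and leaves its input untouched, which makes it easier to check in isolation."
   ]
  },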
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "%%writefile bespoke.py\n",
    "\n",
    "import json\n",
    "import shutil\n",
    "from typing import TYPE_CHECKING, Any, Dict, List, Optional\n",
    "\n",
    "# import numpy as np\n",
    "from datasets import load_dataset\n",
    "\n",
    "from nemo.collections.llm.gpt.data.core import get_dataset_root\n",
    "from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule\n",
    "from nemo.lightning.io.mixin import IOMixin\n",
    "from nemo.utils import logging\n",
    "\n",
    "from functools import lru_cache\n",
    "\n",
    "from nemo.collections.llm.gpt.data.core import create_sft_dataset\n",
    "\n",
    "if TYPE_CHECKING:\n",
    "    from nemo.collections.common.tokenizers import TokenizerSpec\n",
    "    from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs\n",
    "\n",
    "\n",
    "class BespokeDataModule(FineTuningDataModule, IOMixin):\n",
    "    \"\"\"A data module for fine-tuning on the Bespoke dataset.\n",
    "\n",
    "    This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models on the\n",
    "    \"bespokelabs/Bespoke-Stratos-17k\" dataset. It handles data download, preprocessing, splitting, and preparing the data\n",
    "    in a format suitable for training, validation, and testing.\n",
    "\n",
    "    Args:\n",
    "        force_redownload (bool, optional): Whether to force re-download the dataset even if it exists locally. Defaults to False.\n",
    "        delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing. Defaults to True.\n",
    "        See FineTuningDataModule for the other args\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        seq_length: int = 2048,\n",
    "        tokenizer: Optional[\"TokenizerSpec\"] = None,\n",
    "        micro_batch_size: int = 4,\n",
    "        global_batch_size: int = 8,\n",
    "        rampup_batch_size: Optional[List[int]] = None,\n",
    "        force_redownload: bool = False,\n",
    "        delete_raw: bool = True,\n",
    "        seed: int = 1234,\n",
    "        memmap_workers: int = 1,\n",
    "        num_workers: int = 8,\n",
    "        pin_memory: bool = True,\n",
    "        persistent_workers: bool = False,\n",
    "        packed_sequence_specs: Optional[\"PackedSequenceSpecs\"] = None,\n",
    "        dataset_kwargs: Optional[Dict[str, Any]] = None,\n",
    "        dataset_root: str = \"./bespoke\",\n",
    "    ):\n",
    "        self.force_redownload = force_redownload\n",
    "        self.delete_raw = delete_raw\n",
    "\n",
    "        super().__init__(\n",
    "            dataset_root=dataset_root,\n",
    "            seq_length=seq_length,\n",
    "            tokenizer=tokenizer,\n",
    "            micro_batch_size=micro_batch_size,\n",
    "            global_batch_size=global_batch_size,\n",
    "            rampup_batch_size=rampup_batch_size,\n",
    "            seed=seed,\n",
    "            memmap_workers=memmap_workers,\n",
    "            num_workers=num_workers,\n",
    "            pin_memory=pin_memory,\n",
    "            persistent_workers=persistent_workers,\n",
    "            packed_sequence_specs=packed_sequence_specs,\n",
    "            dataset_kwargs=dataset_kwargs,\n",
    "        )\n",
    "\n",
    "    def prepare_data(self) -> None:\n",
    "        # if train file is specified, no need to do anything\n",
    "        if not self.train_path.exists() or self.force_redownload:\n",
    "            dset = self._download_data()\n",
    "            self._preprocess_and_split_data(dset)\n",
    "        super().prepare_data()\n",
    "\n",
    "    def _download_data(self):\n",
    "        logging.info(f\"Downloading {self.__class__.__name__}...\")\n",
    "        return load_dataset(\n",
    "            \"bespokelabs/Bespoke-Stratos-17k\",\n",
    "            cache_dir=str(self.dataset_root),\n",
    "            download_mode=\"force_redownload\" if self.force_redownload else None,\n",
    "        )\n",
    "\n",
    "    def _preprocess_and_split_data(self, dset, train_ratio: float = 0.80, val_ratio: float = 0.15):\n",
    "        logging.info(f\"Preprocessing {self.__class__.__name__} to jsonl format and splitting...\")\n",
    "\n",
    "        test_ratio = 1 - train_ratio - val_ratio\n",
    "        save_splits = {}\n",
    "        dataset = dset.get('train')\n",
    "        split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)\n",
    "        split_dataset2 = split_dataset['test'].train_test_split(\n",
    "            test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed\n",
    "        )\n",
    "        save_splits['training'] = split_dataset['train']\n",
    "        save_splits['validation'] = split_dataset2['train']\n",
    "        save_splits['test'] = split_dataset2['test']\n",
    "\n",
    "        print(\"len training: \", len(save_splits['training']))\n",
    "        print(\"len validation: \", len(save_splits['validation']))\n",
    "        print(\"len test: \", len(save_splits['test']))\n",
    "\n",
    "        for split_name, dataset in save_splits.items():\n",
    "            output_file = self.dataset_root / f\"{split_name}.jsonl\"\n",
    "            with output_file.open(\"w\", encoding=\"utf-8\") as f:\n",
    "                for example in dataset:\n",
    "                    \n",
    "                    conversations = example[\"conversations\"]\n",
    "\n",
    "                    for conversation in conversations:\n",
    "                        if conversation[\"from\"] == \"user\":\n",
    "                            conversation[\"from\"] = \"User\"\n",
    "                        elif conversation[\"from\"] == \"assistant\":\n",
    "                            conversation[\"from\"] = \"Assistant\"\n",
    "                        else:\n",
    "                            raise ValueError(f\"Unknown role: {conversation['role']}\")\n",
    "\n",
    "                    example[\"mask\"] = \"User\"\n",
    "                    example[\"type\"] = \"VALUE_TO_TEXT\"\n",
    "                    \n",
    "                    f.write(json.dumps(example) + \"\\n\")\n",
    "\n",
    "            logging.info(f\"{split_name} split saved to {output_file}\")\n",
    "\n",
    "        if self.delete_raw:\n",
    "            for p in self.dataset_root.iterdir():\n",
    "                if p.is_dir():\n",
    "                    shutil.rmtree(p)\n",
    "                elif '.jsonl' not in str(p.name):\n",
    "                    p.unlink()\n",
    "\n",
    "\n",
    "    @lru_cache\n",
    "    def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):\n",
    "        # pylint: disable=C0115,C0116\n",
    "        return create_sft_dataset(\n",
    "            path,\n",
    "            tokenizer=self.tokenizer,\n",
    "            seq_length=(self.seq_length if is_test or self.packed_sequence_size <= 0 else self.packed_sequence_size),\n",
    "            memmap_workers=self.memmap_workers,\n",
    "            seed=self.seed,\n",
    "            chat=True,\n",
    "            is_test=is_test,\n",
    "            pack_metadata_file_path=None,  # packing is not supported\n",
    "            pad_cu_seqlens=False,\n",
    "            **kwargs,\n",
    "        )"
   ]
  },
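  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`_preprocess_and_split_data` produces the three splits with two calls to `train_test_split`: the first holds out everything that is not training data, and the second divides the held-out part into validation and test. The arithmetic below is a small sketch that checks the `test_size` values used in each stage with the default ratios. It does not call the `datasets` library, and the variable names are illustrative.\n",
    "\n",
    "```python\n",
    "train_ratio, val_ratio = 0.80, 0.15  # defaults of _preprocess_and_split_data\n",
    "test_ratio = 1 - train_ratio - val_ratio\n",
    "\n",
    "# Stage 1: hold out everything that is not training data.\n",
    "first_test_size = val_ratio + test_ratio\n",
    "# Stage 2: split the held-out part into validation and test.\n",
    "second_test_size = test_ratio / (val_ratio + test_ratio)\n",
    "\n",
    "print(f\"stage 1 test_size = {first_test_size:.2f}\")\n",
    "print(f\"stage 2 test_size = {second_test_size:.2f}\")\n",
    "\n",
    "# Overall fractions implied by the two stages.\n",
    "overall_val = first_test_size * (1 - second_test_size)\n",
    "overall_test = first_test_size * second_test_size\n",
    "print(f\"train/val/test = {1 - first_test_size:.2f}/{overall_val:.2f}/{overall_test:.2f}\")\n",
    "```\n",
    "\n",
    "The second stage uses a quarter of the held-out 20%, which recovers the intended 80/15/5 split of the full dataset."
   ]
  },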
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the dataloader \"BespokeDataModule\" is defined, we can configure it with the `run.Config` API. Given that long-CoT data typically contains longer sequences, we set seq_length=16384."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from bespoke import BespokeDataModule\n",
    "\n",
    "def bespoke() -> run.Config[pl.LightningDataModule]:\n",
    "    return run.Config(BespokeDataModule, seq_length=16384, micro_batch_size=1, global_batch_size=32, num_workers=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Fine-tune Qwen with NeMo 2.0 and NeMo-Run.\n",
    "\n",
    "In this part, we will introduce how to fine-tune the Qwen model with NeMo 2.0 and NeMo-Run. We need to configure the SFT training components, including the trainer, logger, optimizer and model, then launch the training using the `llm.finetune` API.\n",
    "\n",
    "\n",
    "#### Step 3.1: Configure SFT with NeMo 2.0\n",
    "First, we need to configure the trainer, logger, optimizer, and other components."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import nemo_run as run\n",
    "from nemo import lightning as nl\n",
    "from nemo.collections import llm\n",
    "from megatron.core.optimizer import OptimizerConfig\n",
    "import torch\n",
    "import lightning.pytorch as pl\n",
    "from pathlib import Path\n",
    "from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed\n",
    "\n",
    "\n",
    "# Configure the trainer\n",
    "# we use 4 GPUs for training and set the max_steps to 300.\n",
    "def trainer() -> run.Config[nl.Trainer]:\n",
    "    strategy = run.Config(\n",
    "        nl.MegatronStrategy,\n",
    "        tensor_model_parallel_size=4,\n",
    "    )\n",
    "    trainer = run.Config(\n",
    "        nl.Trainer,\n",
    "        devices=4,\n",
    "        max_steps=300,\n",
    "        accelerator=\"gpu\",\n",
    "        strategy=strategy,\n",
    "        plugins=bf16_mixed(),\n",
    "        log_every_n_steps=1,\n",
    "        limit_val_batches=0,\n",
    "        val_check_interval=0,\n",
    "        num_sanity_val_steps=0,\n",
    "    )\n",
    "    return trainer\n",
    "\n",
    "\n",
    "# Configure the logger\n",
    "# Here, we configure the log interval to 100 steps and save the model every 100 steps. you can change these parameters as needed.\n",
    "def logger() -> run.Config[nl.NeMoLogger]:\n",
    "    ckpt = run.Config(\n",
    "        nl.ModelCheckpoint,\n",
    "        save_last=True,\n",
    "        every_n_train_steps=100,\n",
    "        monitor=\"reduced_train_loss\",\n",
    "        save_top_k=1,\n",
    "        save_on_train_epoch_end=True,\n",
    "        save_optim_on_train_end=True,\n",
    "    )\n",
    "\n",
    "    return run.Config(\n",
    "        nl.NeMoLogger,\n",
    "        name=\"qwen_sft\",\n",
    "        log_dir=\"./results\",\n",
    "        use_datetime_version=False,\n",
    "        ckpt=ckpt,\n",
    "        wandb=None\n",
    "    )\n",
    "\n",
    "\n",
    "# Configure the optimizer\n",
    "# We use the distributed Adam optimizer and pass in the OptimizerConfig.\n",
    "def adam_with_cosine_annealing() -> run.Config[nl.OptimizerModule]:\n",
    "    opt_cfg = run.Config(\n",
    "        OptimizerConfig,\n",
    "        optimizer=\"adam\",\n",
    "        lr=2e-5,\n",
    "        adam_beta2=0.98,\n",
    "        use_distributed_optimizer=True,\n",
    "        clip_grad=1.0,\n",
    "        bf16=True,\n",
    "    )\n",
    "    return run.Config(\n",
    "        nl.MegatronOptimizerModule,\n",
    "        config=opt_cfg\n",
    "    )\n",
    "\n",
    "\n",
    "# Configure the model\n",
    "# We use Qwen2Config7B to configure the model.\n",
    "def qwen() -> run.Config[pl.LightningModule]:\n",
    "    return run.Config(llm.Qwen2Model, config=run.Config(llm.Qwen2Config7B))\n",
    "\n",
    "# Configure the resume\n",
    "def resume() -> run.Config[nl.AutoResume]:\n",
    "    return run.Config(\n",
    "        nl.AutoResume,\n",
    "        restore_config=run.Config(nl.RestoreConfig,\n",
    "            path=\"nemo://Qwen/Qwen2.5-7B-Instruct\"\n",
    "        ),\n",
    "        resume_if_exists=True,\n",
    "    )"
   ]
  },
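  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch and parallelism settings chosen so far determine how many micro-batches are accumulated per optimizer step. The sketch below applies the usual Megatron-style relation, global batch size = micro batch size x data-parallel size x accumulation steps, to the values in this notebook. It assumes pipeline and context parallelism are left at 1, which this notebook does not set explicitly.\n",
    "\n",
    "```python\n",
    "# Values copied from the configuration cells in this notebook.\n",
    "devices = 4\n",
    "tensor_model_parallel_size = 4\n",
    "micro_batch_size = 1\n",
    "global_batch_size = 32\n",
    "seq_length = 16384\n",
    "max_steps = 300\n",
    "\n",
    "# Assumed values: these are not set anywhere in this notebook.\n",
    "pipeline_model_parallel_size = 1\n",
    "context_parallel_size = 1\n",
    "\n",
    "model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size\n",
    "data_parallel_size = devices // model_parallel\n",
    "grad_accum_steps = global_batch_size // (micro_batch_size * data_parallel_size)\n",
    "\n",
    "print(\"data-parallel replicas:\", data_parallel_size)\n",
    "print(\"micro-batches accumulated per optimizer step:\", grad_accum_steps)\n",
    "print(\"max tokens per optimizer step:\", global_batch_size * seq_length)\n",
    "print(\"samples consumed in\", max_steps, \"steps:\", global_batch_size * max_steps)\n",
    "```\n",
    "\n",
    "With these settings all four GPUs form a single tensor-parallel group, so there is one data-parallel replica and each optimizer step accumulates 32 micro-batches."
   ]
  },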
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 3.2: Configure NeMo 2.0 `llm.finetune` API\n",
    "\n",
    "To use the components defined above, we can call the `llm.finetune` API and pass the components as parameters to it.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "def configure_finetuning_recipe():\n",
    "    return run.Partial(\n",
    "        llm.finetune,\n",
    "        model=qwen(),\n",
    "        trainer=trainer(),\n",
    "        data=bespoke(),\n",
    "        log=logger(),\n",
    "        optim=adam_with_cosine_annealing(),\n",
    "        resume=resume(),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 3.3: Launch Training\n",
    "\n",
    "To launch the training, we use `LocalExecutor` for executing our configured finetune function. In this example, we use 4 GPUs for training.\n",
    "\n",
    "For more details on the NeMo-Run executor, refer to [Execute NeMo Run](https://github.com/NVIDIA/NeMo-Run/blob/main/docs/source/guides/execution.md) of NeMo-Run Guides.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "def local_executor_torchrun(nodes: int = 1, devices: int = 4) -> run.LocalExecutor:\n",
    "    # Env vars for jobs are configured here\n",
    "    env_vars = {\n",
    "        \"TORCH_NCCL_AVOID_RECORD_STREAMS\": \"1\",\n",
    "        \"NCCL_NVLS_ENABLE\": \"0\",\n",
    "        \"NVTE_DP_AMAX_REDUCE_INTERVAL\": \"0\",\n",
    "        \"NVTE_ASYNC_AMAX_REDUCTION\": \"1\",\n",
    "    }\n",
    "\n",
    "    executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\", env_vars=env_vars)\n",
    "\n",
    "    return executor\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    run.run(configure_finetuning_recipe(), executor=local_executor_torchrun())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4. Evaluate Model\n",
    "\n",
    "This section demonstrates how to generate inference results using the trained checkpoint. In this example, we perform inference on a single GPU. You can modify the `tensor_model_parallel_size` and `local_executor_torchrun` functions to utilize multiple GPUs for speedup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from megatron.core.inference.common_inference_params import CommonInferenceParams\n",
    "\n",
    "sft_ckpt_path=str(next((d for d in Path(\"./results/qwen_sft/checkpoints/\").iterdir() if d.is_dir() and d.name.endswith(\"-last\")), None))\n",
    "print(\"We will load SFT checkpoint from:\", sft_ckpt_path)\n",
    "\n",
    "\n",
    "\n",
    "def trainer() -> run.Config[nl.Trainer]:\n",
    "    strategy = run.Config(\n",
    "        nl.MegatronStrategy,\n",
    "        tensor_model_parallel_size=1,\n",
    "    )\n",
    "    trainer = run.Config(\n",
    "        nl.Trainer,\n",
    "        accelerator=\"gpu\",\n",
    "        devices=1,\n",
    "        num_nodes=1,\n",
    "        strategy=strategy,\n",
    "        plugins=bf16_mixed(),\n",
    "    )\n",
    "    return trainer\n",
    "\n",
    "prompts = [\n",
    "    \"How many r's are in the word 'strawberry'?\",\n",
    "    \"Which number is bigger? 10.119 or 10.19?\",\n",
    "]\n",
    "\n",
    "def configure_inference():\n",
    "    return run.Partial(\n",
    "        llm.generate,\n",
    "        path=str(sft_ckpt_path),\n",
    "        trainer=trainer(),\n",
    "        prompts=prompts,\n",
    "        inference_params=CommonInferenceParams(num_tokens_to_generate=16384, top_p=0.6),\n",
    "        output_path=\"sft_prediction.jsonl\",\n",
    "    )\n",
    "\n",
    "\n",
    "def local_executor_torchrun(nodes: int = 1, devices: int = 1) -> run.LocalExecutor:\n",
    "    # Env vars for jobs are configured here\n",
    "    env_vars = {\n",
    "        \"TORCH_NCCL_AVOID_RECORD_STREAMS\": \"1\",\n",
    "        \"NCCL_NVLS_ENABLE\": \"0\",\n",
    "        \"NVTE_DP_AMAX_REDUCE_INTERVAL\": \"0\",\n",
    "        \"NVTE_ASYNC_AMAX_REDUCTION\": \"1\",\n",
    "    }\n",
    "\n",
    "    executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\", env_vars=env_vars)\n",
    "\n",
    "    return executor\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    run.run(configure_inference(), executor=local_executor_torchrun())"
   ]
  },
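  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The system prompt shown in Step 2 asks the model to wrap its reasoning in `<|begin_of_thought|>` and `<|end_of_thought|>` and its answer in `<|begin_of_solution|>` and `<|end_of_solution|>`. When inspecting generations, it can help to separate the two parts. The helper below is a minimal sketch: `extract_section` and the sample string are hypothetical, and it assumes the markers appear verbatim in the decoded text.\n",
    "\n",
    "```python\n",
    "def extract_section(text: str, name: str):\n",
    "    \"\"\"Return the text between <|begin_of_NAME|> and <|end_of_NAME|>, or None if incomplete.\"\"\"\n",
    "    start, end = f\"<|begin_of_{name}|>\", f\"<|end_of_{name}|>\"\n",
    "    begin = text.find(start)\n",
    "    if begin == -1:\n",
    "        return None\n",
    "    begin += len(start)\n",
    "    finish = text.find(end, begin)\n",
    "    if finish == -1:\n",
    "        return None  # for example, generation stopped before the closing marker\n",
    "    return text[begin:finish].strip()\n",
    "\n",
    "\n",
    "# Hypothetical model output used only for illustration.\n",
    "example_output = (\n",
    "    \"<|begin_of_thought|> 10.19 equals 10.190, and 10.190 > 10.119. <|end_of_thought|> \"\n",
    "    \"<|begin_of_solution|> 10.19 is bigger. <|end_of_solution|>\"\n",
    ")\n",
    "print(extract_section(example_output, \"solution\"))\n",
    "```\n",
    "\n",
    "Returning `None` for a missing or unterminated section makes it easy to spot generations that hit the token limit before finishing their solution."
   ]
  },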
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5. Convert Fine-tuned Model to HF Format.\n",
    "\n",
    "After training, you can convert the output model to hf format using the `llm.export_ckpt` API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "sft_ckpt_path=str(next((d for d in Path(\"./results/qwen_sft/checkpoints/\").iterdir() if d.is_dir() and d.name.endswith(\"-last\")), None))\n",
    "\n",
    "print(\"We will load SFT checkpoint from:\", sft_ckpt_path)\n",
    "\n",
    "# llm.export_ckpt is the nemo2 API for exporting a NeMo checkpoint to Hugging Face format\n",
    "# example python usage:\n",
    "# llm.export_ckpt(path=\"/path/to/model.nemo\", target=\"hf\", output_path=\"/path/to/save\")\n",
    "def configure_checkpoint_conversion():\n",
    "    return run.Partial(\n",
    "        llm.export_ckpt,\n",
    "        path=sft_ckpt_path,\n",
    "        target=\"hf\",\n",
    "        output_path=\"./model\"\n",
    "    )\n",
    "\n",
    "# configure your function\n",
    "export_ckpt = configure_checkpoint_conversion()\n",
    "# define your executor\n",
    "local_executor = run.LocalExecutor()\n",
    "\n",
    "# run your experiment\n",
    "run.run(export_ckpt, executor=local_executor)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
